gusucode.com > 支持向量机工具箱 - LIBSVM OSU_SVM LS_SVM源码程序 > 支持向量机工具箱 - LIBSVM OSU_SVM LS_SVM\stprtool\svm\oaosvmlight.m

    function [model] = oaosvmlight( data, labels, ker, arg, C, eps, verb)
% OAOSVMLIGHT One-Against-One multi-class decomposition solved by SVM^Light.
% 
% [model] = oaosvmlight( data, labels, ker, arg, C, eps, verb)
%
% Trains num_classes*(num_classes-1)/2 binary SVM rules, one for every
% pair of classes, by calling the external SVM^Light program 'svm_learn'.
% The programs 'svm_learn' and 'svm_classify' must be in the path.
%
% The intermediate files tmp_examples.txt, tmp_model.txt, tmp_alpha.txt
% and tmp_verb.txt are written to (and left in) the current directory.
% An error is raised if svm_learn returns a non-zero exit status.
%
% Inputs:
%  data [dim x N] training patterns
%  labels [1 x N] labels of training patterns; assumed to be integers
%    1..num_classes, because num_classes is taken as max(labels).
%  ker [string] kernel, see 'help kernel'. 'linear', 'rbf' and 'poly'
%    are translated to SVM^Light options; any other string is passed
%    to svm_learn unchanged.
%  arg [...] argument of given kernel, see 'help kernel'.
%  C [real] trade-off between margin and training error.
%  eps [real] KT stopping condition.
%  verb [int] if 1 then progress info is displayed. 
% 
% Output:
%  model [struct] contains found O-A-O SVM classifier.
%   .rule{k}  k-th binary rule with fields class1, class2, data_inx,
%             labels (1 = class1, 2 = class2), Alpha, bias, kercnt,
%             margin, nsv and trnerr.
%   .trnerr   mean of the training errors of the rules.
%   .nsv      number of training patterns that are SVs of any rule.
%   .kercnt   sum of the kernel evaluation counts reported by svm_learn.
%   .btime    CPU time spent in training.
%  A statistic (trnerr, margin, kercnt) that does not appear in the
%  svm_learn output is stored as NaN.
% 

%  Statistical Pattern Recognition Toolbox, Vojtech Franc, Vaclav Hlavac
%  (c) Czech Technical University Prague, http://cmp.felk.cvut.cz
%  Written Vojtech Franc (diploma thesis) 02.11.1999, 13.4.2000
%
%  Modifications
%   3-Jun-2002, V.Franc

num_data = size(data, 2);
num_classes = max(labels);

model.name = 'One-Against-One, SVM classifier';

model.num_classes = num_classes;
model.num_rules = num_classes*(num_classes-1)/2;
% one cell per rule (cell(n) alone would allocate an n x n cell array)
model.rule = cell(1, model.num_rules);

model.SVM.C = C;
model.SVM.kernel = ker;
model.SVM.arg = arg;

model.trn_data = data;
model.trn_labels = labels;
model.kercnt=0;

trn_errors = zeros(1, model.num_rules);
sv = zeros(1, num_data);   % sv(i)=1 if pattern i is an SV of some rule

cnt=0;

%--------------------------------
% translate the toolbox kernel name into svm_learn options
switch ker
  case 'linear'
    ker='-t 0';
  case 'rbf'
    % SVM^Light's rbf is exp(-g*||a-b||^2), hence g = 1/(2*arg^2)
    ker=['-t 2 -g ' num2str(1/(2*arg^2))]; 
  case 'poly' 
    ker=['-t 1 -r 1 -s 1 -d ' num2str(arg)];  
end

% -v 1 makes svm_learn print the statistics parsed below into tmp_verb.txt;
% -a writes the alpha of every training example into tmp_alpha.txt.
command=['svm_learn ' ...
         '-c ' num2str(C) ' '...
         ker ' '...
         '-v 1' ' ' ...
         '-m 1' ' ' ...
         '-e ' num2str(eps) ' '...
         '-a tmp_alpha.txt tmp_examples.txt tmp_model.txt > tmp_verb.txt'];

% builds num_classes*(num_classes-1)/2 one-against-one classification rules
model.btime=cputime;
for class1=1:num_classes-1,
  for class2=class1+1:num_classes,
  
    cnt=cnt+1;
    
    if verb ==1,
      fprintf(1, 'building rule %d-%d (%d of %d) ...', ...
        class1,class2, cnt, model.num_rules );
    end

    % take the patterns of class1 and class2 and relabel them to 1 and 2
    data_inx = find(labels==class1 | labels==class2);
    rule_labels = labels(data_inx);
    is_class1 = (rule_labels == class1);
    rule_labels(is_class1) = 1;
    rule_labels(~is_class1) = 2;

    model.rule{cnt}.class1 = class1;
    model.rule{cnt}.class2 = class2;
    model.rule{cnt}.data_inx = data_inx;
    model.rule{cnt}.labels = rule_labels;
  
    xi2svmlight(data(:,data_inx), rule_labels, 'tmp_examples.txt');
    
    % call SVM^Light; without this check a failed run would leave the
    % temporary files of the previous rule in place to be parsed below
    [status, cmdout] = unix(command);
    if status ~= 0,
      error('svm_learn failed (exit status %d) while building rule %d-%d.\n%s', ...
        status, class1, class2, cmdout);
    end
    
    bias = read_svmlight_bias('tmp_model.txt');

    % alphas of all examples, signed by itosgn of the 1/2 rule labels
    Alpha = textread('tmp_alpha.txt','%f');
    Alpha = Alpha(:)'.*itosgn(rule_labels);

    [num_err, margin, kercnt] = read_svmlight_stats('tmp_verb.txt');
    trnerr = num_err/length(Alpha);
    
    model.rule{cnt}.Alpha = Alpha;
    model.rule{cnt}.bias = bias;
    model.rule{cnt}.kercnt = kercnt;
    model.rule{cnt}.margin = margin;
    model.rule{cnt}.nsv = nnz(Alpha);
    model.rule{cnt}.trnerr = trnerr;
    model.kercnt = model.kercnt + kercnt;

    trn_errors(cnt) = trnerr;
    
    % mark the training patterns that are support vectors of this rule
    sv(data_inx(Alpha ~= 0)) = 1;
    
    if verb ==1,
      fprintf(1,'done\n');
    end
    
  end
end

model.btime=cputime-model.btime;
model.trnerr = mean( trn_errors);
model.nsv = nnz(sv);

return;

function bias = read_svmlight_bias(model_file)
% READ_SVMLIGHT_BIAS Bias of a rule read from an SVM^Light model file.
%  Returns the negated number found two tokens before the first token
%  'threshold' (case-insensitive). Raises an error if there is none.
tokens = textread(model_file, '%s');
for i = 3:length(tokens)
  if strcmpi(tokens{i}, 'threshold'),
    bias = -str2num(tokens{i-2});
    return;
  end
end
error('No threshold found in %s.', model_file);

function [num_err, margin, kercnt] = read_svmlight_stats(verb_file)
% READ_SVMLIGHT_STATS Training statistics from svm_learn's verbose output.
%  num_err  number in the token preceding 'misclassified,' with its first
%           character (the opening parenthesis) stripped.
%  margin   1 / number in the token following 'weight vector:' with its
%           first 4 characters (e.g. '|w|=') stripped.
%  kercnt   number in the token following 'evaluations:'.
%  Tokens are compared case-insensitively; when a token occurs several
%  times the last occurrence wins; a statistic not found is NaN.
num_err = NaN;
margin = NaN;
kercnt = NaN;
tokens = textread(verb_file, '%s');
n = length(tokens);
for i = 1:n
  tok = tokens{i};
  if i > 1 && strcmpi(tok, 'misclassified,'),
    num_err = str2num(tokens{i-1}(2:end));
  elseif i > 1 && i < n && strcmpi(tok, 'vector:') && strcmpi(tokens{i-1}, 'weight'),
    margin = 1/str2num(tokens{i+1}(5:end));
  elseif i < n && strcmpi(tok, 'evaluations:'),
    kercnt = str2num(tokens{i+1});
  end
end

%EOF